This note demonstrates how to:

- prepare some data in R;
- use the R package reticulate to pass the data to Python;
- fit a NumPyro Poisson mixed model;
- plot the MCMC output using ArviZ.
## ---- Python environment ----
## reticulate is pointed at the "anaconda3" conda environment. The conda
## executable defaults to the original author's installation; set the
## environment variable NUMPYRO_NOTE_CONDA to another conda executable to run
## this note on a different machine (behaviour is unchanged when it is unset).
library(reticulate)
conda_bin <- Sys.getenv("NUMPYRO_NOTE_CONDA")
if (!nzchar(conda_bin)) {
  conda_bin <- "/home/jmellor/anaconda3/bin/conda"
}
use_condaenv("anaconda3", conda = conda_bin)

## Force matplotlib's non-interactive "Agg" backend for the figures produced
## by the Python chunks.
matplotlib <- import("matplotlib")
matplotlib$use("Agg", force = TRUE)
## ---- Simulation settings ----
N_indivs <- 1000 # num individuals
intervals.mean <- 20 # each individual has 1 + Poisson(intervals.mean) intervals
beta <- c(-0.5, 0) # log rate ratio parameters for x1 and x2
gamma <- 0.7 # coefficient for effect of true value of biomarker

## Data-generating model, as implemented below:
## - individual baseline log rate: logbaseline ~ N(-3, 1)
## - true biomarker value z.true ~ N(0, 1), constant within an individual;
##   each interval records a measured value z.obs ~ N(z.true, 1)
## - two fixed individual-level covariates x1, x2 ~ N(0, 1)
## - count in each interval: y ~ Poisson(rate), with
##   log(rate) = logbaseline + gamma * z.true + beta[1] * x1 + beta[2] * x2
## y must be integer (Poisson outcome).
## NOTE(review): earlier notes mentioned a linear mixed model for the marker
## with a random slope in time; that is not simulated here.

## individual-level draws (order matters for reproducibility of the stream)
logbaseline <- rnorm(N_indivs, -3, 1)
mean.z <- rnorm(N_indivs, 0, 1)
x1.indiv <- rnorm(N_indivs, 0, 1) # fixed covariate
x2.indiv <- rnorm(N_indivs, 0, 1) # fixed covariate

## number of person-time intervals per individual (always >= 1).
## NOTE: `T` masks base R's TRUE shorthand for the rest of the session; the
## name is kept for compatibility with any code that refers to it.
T <- 1 + rpois(n = N_indivs, lambda = intervals.mean)
N <- sum(T)

## Long format: one row per individual-interval, individuals in order.
## Built in one vectorised step instead of rbind()-ing inside a loop, which
## copied the growing data frame on every iteration. rnorm() with a vector of
## means draws sequentially, so z.obs consumes the random stream exactly as
## the per-individual loop did.
indiv.row <- rep(seq_len(N_indivs), times = T)
mixdata <- data.frame(
  indiv = factor(indiv.row),
  time = sequence(T),
  logbaseline.indiv = logbaseline[indiv.row],
  x1 = x1.indiv[indiv.row],
  x2 = x2.indiv[indiv.row],
  z.true = mean.z[indiv.row],
  z.obs = rnorm(N, mean = mean.z[indiv.row]) # simulate measured values
)

## Poisson outcome driven by the true (unobserved) marker value
rate <- with(mixdata, exp(logbaseline.indiv + gamma * z.true +
                            beta[1] * x1 + beta[2] * x2))
mixdata$y <- rpois(n = N, lambda = rate)
## ---- Model inputs ----
## Normal prior on the global intercept alpha: mean mean_icept, SD scale_icept.
mean_icept <- 0
scale_icept <- 10

## Covariate matrix; the intercept column is dropped because the model
## carries its own intercept (alpha).
X <- model.matrix(y ~ x1 + x2, data = mixdata)[, -1]
print(head(X))
## Example output (from an unseeded run, so values will differ):
## x1 x2
## 1 0.466426 0.5964039
## 2 0.466426 0.5964039
## 3 0.466426 0.5964039
## 4 0.466426 0.5964039
## 5 0.466426 0.5964039
## 6 0.466426 0.5964039

## Centre each covariate column without rescaling it.
X <- scale(X, center = TRUE, scale = FALSE)

indiv <- as.integer(mixdata$indiv) # 1-based individual index for each row
N <- nrow(X)
U <- ncol(X)

## Everything the Python side needs, collected in a single list.
testdata <- list(
  N = N,
  U = U,
  N_indivs = N_indivs,
  indiv = indiv,
  mean_icept = mean_icept,
  scale_icept = scale_icept,
  scale_other = 3, # SD of the Normal prior on the covariate coefficients
  y = mixdata$y,
  X = X
)
Here we use reticulate's `r` object to access, from Python, the data prepared in R.
import warnings
warnings.filterwarnings("ignore")
import numpy as np


def get_data(source=None):
    """Return the R-side ``testdata`` list as a dict ready for the NumPyro model.

    Args:
        source: optional dict to convert. Defaults to ``r.testdata``, the
            ``testdata`` list built in the R chunk and exposed through
            reticulate's ``r`` object.

    Returns:
        A shallow copy of ``source`` in which
        ``indiv`` is converted from R's 1-based to 0-based integer indices,
        ``y`` is a 1-D numpy array, and
        ``N_indivs`` is a Python ``int``, because the model uses it as an
        array shape.
    """
    data = dict(r.testdata if source is None else source)
    # np.atleast_1d also covers a length-1 R vector, which reticulate hands
    # over as a Python scalar rather than a list.
    data['indiv'] = np.atleast_1d(np.asarray(data['indiv'], dtype=int)) - 1
    data['y'] = np.atleast_1d(np.asarray(data['y']))
    data['N_indivs'] = int(data['N_indivs'])
    return data
# Convert the R-side `testdata` list once; this dict is what `fit(data)` receives below.
data = get_data()
import jax.numpy as jnp
import os
import time as tm
import jax.random as random
import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.infer import MCMC, NUTS
def model(X, y, mean_icept, scale_icept, scale_other, N_indivs, indiv):
    """Poisson regression with a random intercept per individual.

    log rate[n] = alpha + sigma_indiv * xi[indiv[n]] + X[n] @ theta

    The individual intercepts are non-centred: standard-normal ``xi`` scaled by
    ``sigma_indiv`` and shifted by the global intercept ``alpha``.

    Args:
        X: 2-D covariate matrix, one row per observation.
        y: observed counts, one per row of ``X``.
        mean_icept, scale_icept: mean and SD of the Normal prior on ``alpha``.
        scale_other: SD of the Normal prior on each coefficient in ``theta``.
        N_indivs: number of individuals (used as an array shape).
        indiv: 0-based individual index for each row of ``X``.
    """
    alpha = numpyro.sample('alpha', dist.Normal(mean_icept, scale_icept))
    n_cov = X.shape[1]
    theta = numpyro.sample('theta', dist.Normal(jnp.zeros((n_cov, 1)), scale_other))
    xi = numpyro.sample('xi', dist.Normal(jnp.zeros(N_indivs), jnp.ones(N_indivs)))
    sigma_indiv = numpyro.sample('sigma_indiv', dist.HalfCauchy(5.))

    intercepts = alpha + sigma_indiv * xi
    # (N, 1) fixed-effect column flattened to length N, plus each row's intercept
    log_rate = (X @ theta).flatten() + intercepts[indiv]

    n_obs = X.shape[0]
    with numpyro.plate('data', n_obs):
        numpyro.sample('obs', dist.Poisson(jnp.exp(log_rate)), obs=y)
def fit(data, platform='gpu', num_chains=4, num_samples=1000,
        num_warmup=1000, seed=0):
    """Run NUTS on ``model`` and return the fitted ``MCMC`` object.

    Args:
        data: dict from ``get_data()`` with keys X, y, mean_icept,
            scale_icept, scale_other, N_indivs and indiv.
        platform: JAX platform passed to ``numpyro.set_platform``
            ('cpu', 'gpu' or 'tpu').
        num_chains: number of MCMC chains; also the number of host (CPU)
            devices requested, so that chains can run in parallel on CPU.
        num_samples: draws kept per chain.
        num_warmup: warm-up draws discarded per chain.
        seed: integer seed for the JAX PRNG key.

    Returns:
        The ``numpyro.infer.MCMC`` object after ``run``.
    """
    # NumPyro documents both settings as taking effect only at the beginning
    # of the program. The original code called them after random.PRNGKey(),
    # which already makes JAX initialise a backend, so the host device count
    # was silently ignored. They must run before any other JAX computation in
    # the session.
    numpyro.set_platform(platform)
    numpyro.set_host_device_count(num_chains)

    _, run_key = random.split(random.PRNGKey(seed))
    sampler = MCMC(NUTS(model), num_warmup=num_warmup,
                   num_samples=num_samples, num_chains=num_chains)
    sampler.run(run_key,
                X=data['X'],
                y=data['y'],
                mean_icept=data['mean_icept'],
                scale_icept=data['scale_icept'],
                scale_other=data['scale_other'],
                N_indivs=data['N_indivs'],
                indiv=data['indiv'])
    return sampler


mcmc = fit(data)
import arviz as az
import matplotlib.pyplot as plt

# Wrap the NumPyro posterior draws in ArviZ's InferenceData container, then
# draw trace plots for the sampled sites and show the figure.
inf_data = az.from_numpyro(posterior=mcmc)
az.plot_trace(data=inf_data)
plt.show()